import torch
import pandas as pd
import numpy as np
import gc
import time
import matplotlib.pyplot as plt
import os
import joblib
from tqdm.auto import tqdm 
import mlflow

# from scipy.interpolate import interp1d
from mlflow.models import infer_signature
from math import pi, sqrt, exp
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchinfo import summary
from timm.scheduler import CosineLRScheduler
plt.style.use("ggplot")

from pyarrow.parquet import ParquetFile
import pyarrow as pa 
import ctypes

def normalize(y):
    """Z-score each of the two channels independently (modifies y in place)."""
    for c in range(2):
        mean = y[:, c].mean().item()
        std = y[:, c].std().item()
        y[:, c] = (y[:, c] - mean) / (std + 1e-16)
    return y
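
# Quick sanity check for normalize() (illustrative only; assumes a float
# tensor with two columns, like the downsampled targets built below):
#
#   y = normalize(torch.randn(100, 2) * 5 + 3)
#   # each column now has mean ~0 and std ~1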

device = 'cuda' if torch.cuda.is_available() else 'cpu'


EPOCHS = 5
WARMUP_PROP = 0.2  # fraction of total steps used for LR warmup
BS = 1             # one full sequence per batch
WORKERS = 16
TRAIN_PROP = 0.8   # train/validation split
max_chunk_size = 150000  # longest sequence chunk fed to the GRU at once
if device=='cpu':
    torch.set_num_interop_threads(WORKERS)
    torch.set_num_threads(WORKERS)
    
class ResidualBiGRU(nn.Module):
    def __init__(self, hidden_size, n_layers=1, bidir=True):
        super(ResidualBiGRU, self).__init__()

        self.hidden_size = hidden_size
        self.n_layers = n_layers

        self.gru = nn.GRU(
            hidden_size,
            hidden_size,
            n_layers,
            batch_first=True,
            bidirectional=bidir,
        )
        dir_factor = 2 if bidir else 1
        self.fc1 = nn.Linear(
            hidden_size * dir_factor, hidden_size * dir_factor * 2
        )
        self.ln1 = nn.LayerNorm(hidden_size * dir_factor * 2)
        self.fc2 = nn.Linear(hidden_size * dir_factor * 2, hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)

    def forward(self, x, h=None):
        res, new_h = self.gru(x, h)
        # res.shape = (batch_size, sequence_size, hidden_size * dir_factor)

        res = self.fc1(res)
        res = self.ln1(res)
        res = nn.functional.relu(res)

        res = self.fc2(res)
        res = self.ln2(res)
        res = nn.functional.relu(res)

        # skip connection
        res = res + x

        return res, new_h
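
# Note: the skip connection (res + x) requires fc2 to map back to hidden_size,
# so each ResidualBiGRU block preserves the feature dimension. Illustrative:
#
#   blk = ResidualBiGRU(hidden_size=64)
#   out, h = blk(torch.randn(1, 50, 64))  # out.shape == (1, 50, 64)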

class MultiResidualBiGRU(nn.Module):
    def __init__(self, input_size, hidden_size, out_size, n_layers, bidir=True):
        super(MultiResidualBiGRU, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.out_size = out_size
        self.n_layers = n_layers

        self.fc_in = nn.Linear(input_size, hidden_size)
        self.ln = nn.LayerNorm(hidden_size)
        self.res_bigrus = nn.ModuleList(
            [
                ResidualBiGRU(hidden_size, n_layers=1, bidir=bidir)
                for _ in range(n_layers)
            ]
        )
        self.fc_out = nn.Linear(hidden_size, out_size)

    def forward(self, x, h=None):
        # if we are at the beginning of a sequence (no hidden state)
        if h is None:
            # (re)initialize the hidden state
            h = [None for _ in range(self.n_layers)]

        x = self.fc_in(x)
        x = self.ln(x)
        x = nn.functional.relu(x)

        new_h = []
        for i, res_bigru in enumerate(self.res_bigrus):
            x, new_hi = res_bigru(x, h[i])
            new_h.append(new_hi)

        x = self.fc_out(x)
#         x = F.normalize(x,dim=0)
        return x, new_h  # log probabilities + hidden states
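
# Shape sketch for MultiResidualBiGRU (illustrative; matches the configuration
# used below, input_size=10 -> out_size=2, sequence length preserved):
#
#   m = MultiResidualBiGRU(input_size=10, hidden_size=64, out_size=2, n_layers=5)
#   out, h = m(torch.randn(1, 100, 10))  # out.shape == (1, 100, 2), len(h) == 5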

SIGMA = 720  # width of the Gaussian target bump; a full day is 24*60*12 = 17280 steps (5 s each) for comparison
SAMPLE_FREQ = 12  # downsample factor: 12 steps of 5 s -> 1 observation per minute
class SleepDataset(Dataset):
    def __init__(
        self,
        file
    ):
        # (targets, per-series dataframes, series ids) produced by the preprocessing step
        self.targets, self.data, self.ids = joblib.load(file)
            
    def downsample_seq_generate_features(self, feat, downsample_factor=SAMPLE_FREQ):
        # downsample the sequence and aggregate each window into summary statistics;
        # pad with the last value when the length is not divisible by the factor
        if len(feat) % downsample_factor != 0:
            feat = np.concatenate([feat, np.zeros(downsample_factor - (len(feat) % downsample_factor)) + feat[-1]])
        feat = np.reshape(feat, (-1, downsample_factor))
        feat_mean = np.mean(feat, 1)
        feat_std = np.std(feat, 1)
        feat_median = np.median(feat, 1)
        feat_max = np.max(feat, 1)
        feat_min = np.min(feat, 1)

        return np.dstack([feat_mean, feat_std, feat_median, feat_max, feat_min])[0]
    def downsample_seq(self, feat, downsample_factor=SAMPLE_FREQ):
        # downsample by taking the mean of each window, padding as above
        if len(feat) % downsample_factor != 0:
            feat = np.concatenate([feat, np.zeros(downsample_factor - (len(feat) % downsample_factor)) + feat[-1]])
        feat = np.reshape(feat, (-1, downsample_factor))
        return np.mean(feat, 1)
    
    def gauss(self, n=SIGMA, sigma=SIGMA * 0.15):
        # discrete Gaussian kernel with n + 1 points, centred at 0
        r = range(-int(n / 2), int(n / 2) + 1)
        return [1 / (sigma * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * sigma ** 2)) for x in r]
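
    # Illustrative property of gauss(): it returns SIGMA + 1 values peaking at
    # the centre, which is what the slicing in __getitem__ relies on, e.g.
    #   k = self.gauss()  # len(k) == SIGMA + 1 and max(k) == k[len(k) // 2]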
    
    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        X = self.data[index][['anglez','enmo']]
        y = self.targets[index]
        
        # render the (onset, wakeup) event indices as Gaussian bumps over the sequence
        target_gaussian = np.zeros((len(X), 2))
        for s, e in y:
            st1, st2 = max(0, s - SIGMA // 2), s + SIGMA // 2 + 1
            ed1, ed2 = e - SIGMA // 2, min(len(X), e + SIGMA // 2 + 1)
            target_gaussian[st1:st2, 0] = self.gauss()[st1 - (s - SIGMA // 2):]
            target_gaussian[ed1:ed2, 1] = self.gauss()[:SIGMA + 1 - ((e + SIGMA // 2 + 1) - ed2)]
        y = target_gaussian
        gc.collect()
        X = np.concatenate([self.downsample_seq_generate_features(X.values[:,i],SAMPLE_FREQ) for i in range(X.shape[1])],-1)
        gc.collect()
        y = np.dstack([self.downsample_seq(y[:,i],SAMPLE_FREQ) for i in range(y.shape[1])])[0]
        gc.collect()
        y = normalize(torch.from_numpy(y))
        X = torch.from_numpy(X)
        return X, y
train_ds = SleepDataset('./train_data.pkl')
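
# Single-item smoke test (commented out because __getitem__ is relatively
# expensive; the feature count is 5 statistics x 2 channels = 10):
#
#   X, y = train_ds[0]
#   # X.shape ~= (seq_len // SAMPLE_FREQ, 10), y.shape ~= (seq_len // SAMPLE_FREQ, 2)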



train_size = int(TRAIN_PROP * len(train_ds))
valid_size = len(train_ds) - train_size
indices = torch.randperm(len(train_ds))
train_sampler = SubsetRandomSampler(indices[:train_size])
valid_sampler = SubsetRandomSampler(
    indices[train_size : train_size + valid_size]
)



def evaluate(
    model: nn.Module, max_chunk_size: int, loader: DataLoader, device, criterion, epoch, best_valid_loss, dt
):
    model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in tqdm(loader, desc="Eval", unit="batch"):
            # some of the sequences are pretty long, so we use a chunk-based
            # approach, carrying the hidden state across chunks
            y_batch = y_batch.to(device, non_blocking=True)
            pred = torch.zeros(y_batch.shape).to(device, non_blocking=True).half()

            # (re)initialize the model's hidden state(s)
            h = None

            # iterate over the chunks of this sequence (we assume batch size = 1)
            seq_len = X_batch.shape[1]
            for i in range(0, seq_len, max_chunk_size):
                X_chunk = X_batch[:, i : i + max_chunk_size].float().to(device, non_blocking=True)

                y_pred, h = model(X_chunk, h)
                h = [hi.detach() for hi in h]
                pred[:, i : i + max_chunk_size] = y_pred.half()
                del X_chunk
                gc.collect()
            loss = criterion(
                pred.float(),
                y_batch.float(),
            )
            valid_loss += loss.item()
            del pred, loss
            gc.collect()

    valid_loss /= len(loader)
    gc.collect()

    # log the model whenever the validation loss improves
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        mlflow.pytorch.log_model(model, "model")

    dt = time.time() - dt
    print(
        f"{epoch + 1}/{EPOCHS} -- ",
        f"valid_loss = {valid_loss:.6f} -- ",
        f"time = {dt:.6f}s",
    )
    mlflow.log_metric("eval_loss", f"{valid_loss:.6f}", step=epoch)
    return best_valid_loss

def train(train_loader,model,criterion,optimizer,scheduler,epoch):
    train_loss = 0.0
    n_tot_chunks = 0
    pbar = tqdm(
        train_loader, desc="Training", unit="batch"
    )
    model.train()
    for step, (X_batch, y_batch) in enumerate(pbar):
        y_batch = y_batch.to(device, non_blocking=True)
        pred = torch.zeros(y_batch.shape).to(device, non_blocking=True)
        optimizer.zero_grad()
        # the timm scheduler is driven with an absolute (global) step index
        scheduler.step(step + train_size * epoch)
        h = None

        seq_len = X_batch.shape[1]
        for i in range(0, seq_len, max_chunk_size):
            X_chunk = X_batch[:, i : i + max_chunk_size].float()

            X_chunk = X_chunk.to(device, non_blocking=True)

            y_pred, h = model(X_chunk, h)
            h = [hi.detach() for hi in h]
            pred[:, i : i + max_chunk_size] = y_pred
            del X_chunk,y_pred

        loss = criterion(
            normalize(pred).float(),
            y_batch.float(),
        )
        loss.backward()
        train_loss += loss.item()
        n_tot_chunks += 1
        mlflow.log_metric(f"epoch_{epoch}_step_training_loss", f"{(train_loss / n_tot_chunks):.6f}", step=step)

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e-1)
        optimizer.step()
        del pred, loss, y_batch, X_batch, h
        gc.collect()
    train_loss /= len(train_loader)
    mlflow.log_metric("training_loss", f"{train_loss:.6f}", step=epoch)

    del pbar
    gc.collect()
    # return freed heap pages to the OS (glibc-specific; no-op elsewhere)
    ctypes.CDLL("libc.so.6").malloc_trim(0)


steps = train_size * EPOCHS  # one scheduler step per batch (batch size 1, so train_size batches per epoch)
warmup_steps = int(steps * WARMUP_PROP)
model = MultiResidualBiGRU(input_size=10, hidden_size=64, out_size=2, n_layers=5).to(device)  # 5 stats x 2 channels = 10 features

learning_rate = 1e-4
weight_decay = 0.1
model = torch.nn.DataParallel(model)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = CosineLRScheduler(optimizer, t_initial=steps, warmup_t=warmup_steps, warmup_lr_init=1e-6, lr_min=2e-8)
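# Note: timm's CosineLRScheduler is stepped with an absolute step index
# (scheduler.step(global_step) inside train()), so t_initial and warmup_t
# above are expressed in batches, not epochs.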
dt = time.time()
loss_function = torch.nn.MSELoss()
train_loader = DataLoader(
    train_ds,
    batch_size=BS,
    sampler=train_sampler,
    pin_memory=True,
    num_workers=WORKERS,
)
valid_loader = DataLoader(
    train_ds,
    batch_size=1,
    sampler=valid_sampler,
    pin_memory=True,
    num_workers=WORKERS,
)


mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("/mlflow-project-2")

mlflow.pytorch.autolog()




with mlflow.start_run(log_system_metrics=True) as run:
    params = {
        "epochs": EPOCHS,
        "batch_size": BS,
        "loss_function": loss_function.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "learning_rate":learning_rate,
        "weight_decay":weight_decay,
    }
    # Log training parameters.
    mlflow.log_params(params)
    # Log model summary.
    with open("model_summary.txt", "w") as f:
        f.write(str(summary(model)))
    mlflow.log_artifact("model_summary.txt")
    
    best_valid_loss = np.inf
    for t in range(EPOCHS):
        print(f"Epoch {t+1}\n-------------------------------")
        dt = time.time()
        train(train_loader=train_loader, model=model, criterion=loss_function, optimizer=optimizer, scheduler=scheduler, epoch=t)
        best_valid_loss = evaluate(model=model, max_chunk_size=max_chunk_size, loader=valid_loader, device=device, criterion=loss_function, epoch=t, best_valid_loss=best_valid_loss, dt=dt)
    # The best model (by validation loss) is logged to MLflow inside evaluate().